Skip to content

feat: shared-engine refactor + Nemotron-Nano-30B GB10 prefill/decode optimizations (~5×)#20

Draft
TheTom wants to merge 163 commits into
devfrom
tom/feat/cuda-hip-vulkan-backends
Draft

feat: shared-engine refactor + Nemotron-Nano-30B GB10 prefill/decode optimizations (~5×)#20
TheTom wants to merge 163 commits into
devfrom
tom/feat/cuda-hip-vulkan-backends

Conversation

@TheTom

@TheTom TheTom commented Jun 6, 2026

Copy link
Copy Markdown
Contributor

First time here? Read the TL;DR + "where to start" map below — you don't need to read all 104 commits. Big draft (159 files, ~29K lines) opened for visibility.

TL;DR

Two things, both backend-gated so the default path can't regress:

  1. Shared-engine refactor — each model's forward is written once and runs on both CUDA (server GPU) and Metal (Apple) backends; +ffai-vulkan for AMD.
  2. A deep prefill/decode optimization pass on Nemotron-Nano-30B (Mamba2 + 128-expert-MoE hybrid) on a single GB10.

Prefill result so far (S=2048, GB10, byte-matched precision, argmax-gated)

context depth start of pass now
d0 187 tok/s ~411–446
d32768 103 (collapsing) ~334 (flat)

The curve went from collapsing to flat ~370 across d0→d32768 — ~2.3× at d0, ~3.4× at deep context. Every step correctness-gated (logit cosine + shared-top-5, not just argmax).

The levers (each gated, each measured)

  • Conv double-shift fix — a real correctness bug: the all-device Mamba decode step shifted the causal-conv ring twice, corrupting multi-token decode past token 0 (invisible at pos 0, so Phase-1 missed it).
  • SSD matmul scan (NEMOTRON_SSD_MATMUL) — Mamba2 state-space-duality chunked-matmul via batched cuBLAS; ssm_scan stage 26%→5% (6.8×), +36% e2e @ d0. Scan-level cosine 1.000000.
  • Tensor-core flash-attention (sdpa_multi_tc, depth auto-select) — cuBLAS online-softmax, GQA stride-0 fan-out; sdpa stage 86%→30% @ d32768 (up to 14×), +59% deep-context e2e.
  • Launch-geometry fix — six prefill elementwise kernels were dispatched block:[1,1,1] (one thread/block!); packed into real blocks → +7–13%, byte-identical.
  • Batched prefill on Metal (backend-gated) — first Nemotron-30B prefill on Apple Silicon (M5 Max, 33.9 tok/s @ S=2048), resident weights, EXACT-MATCH vs reference. CUDA path untouched. Adds an MLX-4bit loader.
  • Earlier: cuBLAS tensor-core GEMM escape hatch, fully on-device MoE, batched rope/kv/conv, W4A16 MoE GEMM family, deep-context KV-cap IMA fix.

Honestly scoped

  • The MoE expert GEMM is settled at per-expert cuBLAS (~27 TFLOP/s) — three faster-kernel attempts (TC-grouped, on-device, grouped-Marlin) all proven slower; it's memory-pipe-bound, would need a multi-day CUTLASS-grade rewrite. Profiling shows the next real win is the GPU-idle gap (per-forward cuMemAlloc/cuMemFree device-syncs → a caching allocator, in progress), not the GEMM.
  • vs a mature serving engine's prefill (~6,400 t/s) we're not there yet — but the curve is flat and every gain is measurement-backed.

Where to start reviewing

  • Refactor: ffai-modeltests/ (forward written once), ffai-runtime/, ffai-loader/ (+MLX dequant).
  • Prefill levers: ffai-ops/src/lib.rs (kernels: sdpa_multi_tc, ssm_prefill_scan_ssd, launch-geometry) + ffai-modeltests/src/lib.rs forward_batched (gates + backend select).
  • Big line-count is per-backend test/bench files — skim.

Build/validate

cargo test --release -p ffai-cuda --features cuda --test nemotron; per-op profiler NEMOTRON_PROFILE=1; correctness NEMOTRON_PREFILL_CHECK=1. Pairs with the metaltile kernel PR. WIP — happy to walk through live.

TheTom added 30 commits May 27, 2026 08:57
… regressions

Ports the canonical TQ+ kld_vs_baseline harness from
`bench-tq+/harness/kld_vs_baseline.py` (in /Users/tom/local_llms/llama.cpp)
to FFAI Swift. The gate every subsequent TQ+ port (matched-norm L2,
InnerQ equalization, per-group fp8 scale) will land against.

Modules:
* Sources/FFAI/Quality/KLDivergence.swift — per-position + aggregate
  metrics on raw logit pairs. Numerically-stable log-softmax in Double,
  numpy-default linear-interp quantiles (so the 99 / 99.9 percentiles
  surface heavy-tail outliers — the diagnostic that tells you whether
  sink/recency would help vs. position-agnostic codec improvement).
  Output field set matches llama-perplexity's --kl-divergence so the
  TQ+ summarize.py scripts can ingest FFAI runs.
* Sources/FFAI/Quality/LogitsEmitter.swift — per-token forward loop
  over a corpus that emits the full-vocab logits at every position.
  Uses Tensor.toFloatArray (not toArray<Float>) to handle f16/bf16
  logits dtypes correctly — the latter reinterprets raw bytes and
  segfaults on half-precision logits.

Tests:
* 8 unit tests in Tests/FFAITests/Quality/KLDivergenceTests.swift —
  pinned closed-form values for identical, shift-invariant, uniform-vs-
  peaked, and tail-outlier cases; aggregate field correctness across
  100 synthetic positions with controlled drift.
* 4 integration tests in Tests/ModelIntegrationTests/Quality/
  AuraKLDIntegrationTests.swift on local Qwen3-0.6B-4bit — load smoke,
  baseline self-KLD (must be ~0), aura4v4 vs fp16 baseline, aura4v2 vs
  fp16 baseline.

First measured AURA quality on FFAI (Qwen3-0.6B-4bit, 64-token diverse
sample prompt, M5 Max):

  aura4v4 (default):   mean_kld=1.2414 same_top=47.54% max=7.70
  aura4v2 (production): mean_kld=1.9307 same_top=42.62% max=9.69

For reference, canonical TQ+ on llama.cpp gets same_top > 99% at
comparable bit-widths. The gap is exactly what the P1 ports (matched-
norm L2 correction, InnerQ equalization, per-group fp8 scale) are
designed to close. Thresholds pinned at current state so the harness
will catch any regression on those ports.
Adds 4 more bench cases to AuraKLDIntegrationTests so the curve has
data points for every supported AURA scheme bit-width — needed before
prioritising TQ+ ports:

* aura3v3 — skipped (precondition trip in Ops.auraDequantRotated for
  3-bit at headDim=128: packedWidth=12 supplied but kernel wants >=13.
  Real bug to file separately, but not the priority right now).
* aura2v2 — characterises the bottom of the curve.
* aura8v4 — TQ+'s canonical production recipe (high-bit K, aggressive
  V). Demonstrates the "Why V is Free, K is Everything" thesis
  empirically.

Measured on Qwen3-0.6B-4bit / M5 Max (one-shot, 64-token sample
prompt):

  fp16 baseline:       mean_kld=0.0000  same_top=100.00%
  aura8v8 (8-bit sym): mean_kld=0.0052  same_top= 95.08%
  aura8v4 (TQ+ recipe):mean_kld=0.0285  same_top= 88.52%
  aura4v4 (4-bit sym): mean_kld=1.2414  same_top= 47.54%
  aura4v2 (asym):      mean_kld=1.9307  same_top= 42.62%
  aura2v2 (2-bit sym): mean_kld=4.6193  same_top= 13.11%

Takeaways:
  * K bit-width dominates attention quality. Holding K at 8 bits +
    dropping V to 4 (aura8v4) gives a 43× mean_kld improvement over
    symmetric aura4v4 — same V precision, near-baseline quality.
  * AURA + TQ+ centroids are byte-identical at 2-bit + 3-bit
    (verified vs llama.cpp ggml-turbo-quant.c CENTROIDS_*). The
    quality gap is not codebook.
  * Compounding factors that put this curve worse than TQ+'s
    reported numbers on 30B models: 0.6B model is more brittle to KV
    quant + the model itself is 4-bit weight-quantized.

Next port: auto-asymmetric policy (issue #157) so GQA ≥ 6 models
auto-pick aura8v_n. Production-shape Qwen3.6-A3B (GQA=8) would auto-
engage without user opt-in.
…esets

Ports canonical TQ+'s TURBO_AUTO_ASYMMETRIC behavior to AURA. When
the model's GQA fan-out is ≥ 6 (Qwen3.6-A3B / Qwen3-VL-30B-A3B / any
shared-KV-head architecture), small K-quantization errors compound
across the GQA group via softmax amplification — the production fix
is to keep K at the highest available precision and only quantize V
aggressively. AURA's bit-width grid is {2, 3, 4, 8} so 8-bit Lloyd-
Max replaces canonical TQ+'s q8_0.

Empirical motivation (Qwen3-0.6B-4bit / FFAI KLD harness):

  aura8v8:  mean_kld=0.005, same_top=95%
  aura8v4:  mean_kld=0.029, same_top=89%   ← TQ+ production recipe
  aura4v4:  mean_kld=1.24,  same_top=48%
  aura2v2:  mean_kld=4.62,  same_top=13%

Holding K at 8 bits and dropping V to 4 (aura8v4) gives a 43× mean_kld
improvement over symmetric aura4v4. K bit-width dominates attention
quality; V precision is roughly free.

Adds:
  * `AURAScheme.aura8v4` + `aura8v2` named presets (parse: aura8v4 /
    aura8v2 strings).
  * `AURAScheme.autoAsymmetric(requested:gqaFactor:)` static resolver.
    GQA < 6 → return requested unchanged. GQA ≥ 6 and keyBits < 8 →
    bump keyBits to 8. Default ON; `FFAI_AURA_AUTO_ASYM=0` disables.
    Threshold of 6 matches canonical TQ+ at
    `~/local_llms/llama.cpp/src/llama-kv-cache.cpp`.
  * Applied at the two FFAI AURA cache construction sites
    (`Qwen3Model.makeKVCache`, `LlamaModel.makeKVCache`).
  * 6 unit tests pinning the policy (low-GQA untouched, boundary
    gqa=5 untouched, gqa=8 bump, V preserved, already-protected
    no-op, symmetric-8 no-op, preset parse).

Regression check: Qwen3-0.6B-4bit (gqa=2, below threshold) — aura4v4
KLD is byte-identical (mean_kld=1.2414, same_top=47.54%). Low-GQA
models unaffected.

Production-shape Qwen3.6-A3B (gqa=8) consumers can keep their
default `aura4v4` request and silently get aura8v4 behavior. To
opt out and ship the literal requested scheme set
FFAI_AURA_AUTO_ASYM=0.
…c range

Adds two focused round-trip tests asserting that AURA's encode/dequant
pair preserves the input row's L2 norm to within fp16 noise (rel_err
< 1e-3) across three magnitudes (0.25 / 1.0 / 4.0) at both 4-bit and
8-bit. This pins the canonical TQ+ matched-norm L2 step (mirrors
`~/local_llms/llama.cpp/ggml/src/ggml-turbo-quant.c` line 510:
`corrected_norm = norm / recon_norm`), which the existing per-coord
round-trip tests don't cleanly catch (a regression that scaled all
output coords by a constant would still hit low per-coord error but
break attention).

Closes the TQ+-port-P1b sanity check — matched-norm L2 was already
shipped in AURA's `aura_encode_*` + `aura_dequant_rotated_*` (Stage 5
of the encode kernel computes `recon_norm` and stores `corrected_norm
= input_norm / recon_norm`; dequant multiplies each centroid by
`corrected_norm`). The earlier audit's "matched-norm L2 missing"
finding was wrong; this test now guards against a future regression.
…layer wiring deferred)

Adds the FFAI Ops surface for AURA's compressed flash decode path —
maps to metaltile's existing `aura_flash_sdpa_kb4_v{b2,b4}_d128_{f32,
f16,bf16}` kernels. Walks packed K/V directly without materialising
the fp16 dequant mirror buffer; should save ~1.8 GB / decode step at
Qwen3-1.7B / maxSeq=32K.

Scope:
* Ops.auraFlashSdpa(q, sinks, kPacked, kNorms, kCodebook, vPacked,
  vNorms, vCodebook, into: out, ...) — 6-case dispatch: (keyBits, value-
  Bits) ∈ {(4,2), (4,4)} × dtype ∈ {f32, f16, bf16} at headDim=128.
  Casts q to f32 internally when needed (kernel pins q_rot dtype).
* Ops.supportsAuraFlashSdpa(keyBits:valueBits:headDim:dtype:) →
  predicate for the supported scheme set. d=64 metaltile kernels
  exist but their Ops.swift dispatch isn't wired in this commit;
  gate to d=128.

What this does NOT include:
* Model-layer call site (Qwen3Layer.forward). First wiring attempt
  produced mean_kld=14.30 / same_top=0% on aura4v4 (Qwen3-0.6B-4bit)
  vs the dequant-mirror's 1.24 / 47.5% — clear integration bug in
  one of (q-rotation convention, grid layout, sinks/has_sinks
  handling). Reverted to keep `AURADecodePath.compressed` silently
  downgrading to `.dequantMirror`. Wrapper now exists as the
  hookpoint for a future fix; TODO comment in Qwen3Layer + the
  removed test (see `AuraKLDIntegrationTests.swift`) document the
  outstanding work.

Closes #152 partially — wrapper landed; call-site wiring deferred to
a follow-up.
Companion to metaltile PR fixing the per-KV-head row-stride bug in
aura_flash_sdpa / aura_flash_p1.

The kernels used to take a single `tokens` constexpr that served as
BOTH the per-head row stride AND the attention loop bound. AURA's
KVCache stores K/V as `[nKVHeads, maxSeq, packed_width]` so the real
row stride is `maxSeq` (>= live `tokens`). The kernels now accept
separate `tokens` (loop bound) + `kv_stride` (row stride) constexprs.

This commit adds a required `kvStride: Int` parameter to
`Ops.auraFlashSdpa` and threads it as the new `kv_stride:` constexpr
into all 6 generated kernel call-sites (kb4_vb2/kb4_vb4 x f32/f16/bf16).
Asserts kvStride >= liveLength.

Callers MUST pass `kvStride = cache.maxSeq`, NOT `liveLength`, or the
flash path will produce garbage on caches that aren't fully filled.

Model-layer wiring (Qwen3Layer.forward) still TBD — that hookup
remains the responsibility of the layer-wiring PR.
Pre-scale (1/√headDim) is a kernel contract — aura_flash_sdpa.rs header
states 'q_rot is WHT-rotated AND pre-scaled by caller'. Earlier wiring
attempt skipped this, producing mean_kld=14.3 on aura4v4. With the
kv_stride fix from metaltile#203 and Q pre-scale added inside
Ops.auraFlashSdpa, compressed flash now matches dequant-mirror:

  aura4v4 dequant-mirror: mean_kld=1.2414 same_top=0.4754
  aura4v4 compressed:     mean_kld=1.1880 same_top=0.5246  ✓
  aura4v2 dequant-mirror: mean_kld=1.9307 same_top=0.4262
  aura4v2 compressed:     mean_kld=2.0526 same_top=0.3115  (small 2-bit-V gap)

Wires the .compressed decode path in Qwen3Layer.forward when the cache
is AURAQuantizedKVCache + Ops.supportsAuraFlashSdpa is true. Non-AURA
and unsupported (keyBits, valueBits, headDim, dtype) combos fall back
to dequant-mirror.

Adds .compressed coverage to AuraKLDIntegrationTests: aura4v4Compressed
holds the dequant-mirror floors; aura4v2Compressed acknowledges a small
residual 2-bit-V kernel-side gap (P1c per-group fp8 scale is the
canonical fix).
Trivial nits from @ekryski's PR #15 review:

* Copyright headers on the 4 new files updated to credit both authors
  (`Eric Kryski (@ekryski) and Tom Turney (@TheTom)`), matching the
  established convention from d2367da.

* Auto-asymmetric policy is now opt-in. AURAScheme.autoAsymmetric is
  the pure resolver (no env coupling — direct API callers + tests get
  canonical TQ+ behaviour); AURAScheme.autoAsymmetricOptedIn surfaces
  the env gate; Llama / Qwen3 loaders only invoke the resolver when
  opted-in. Default is OFF; FFAI_AURA_AUTO_ASYM=1 enables. Matches
  @ekryski's 'no magic by default' stance. A per-load LoadOptions
  flag will replace the env knob in a follow-up.

Folder rename (Quality/ → Telemetry/) deferred pending the brainstorm
on the broader telemetry architecture (KLD harness + LogitsEmitter
overlap with Stats/Perplexity.swift + Sampling.swift + GenerationStats
— posted a sketch on the PR thread).
…ribution

Wraps the four primary hot-path entry points in Profile.signpost(...)
blocks so Metal kernel dispatches nest under the right phase span when
running under Instruments / xctrace at profiling level 2:

  - Qwen35MoEModel.forward — model.embed, model.layer_loop,
    model.final_norm_lm_head
  - Qwen35AttentionMixer.forward — attn.forward
  - Qwen35GDNMixer.forward — gdn.forward
  - MoELayer.decode — moe.decode

Profile.signpost is zero-cost when Profile.shared.level < .signposts
(default off), so no overhead at production. Verified by bench:
prefill 197.19 → 196.33 tps, decode 92.16 → 91.41 tps (within noise).

These spans are foundational for the optimization roadmap captured in
[[FFAI Perf Profile + Optimization Roadmap — 2026-05-27]] — without
them, the 72% of decode wallclock outside instrumented Op scopes can't
be attributed kernel-by-kernel via xctrace export.

Future work: deeper per-Op signposts inside each mixer (qkv / sdpa /
oProj boundary spans) — current 4 wraps give per-mixer phase
attribution; per-Op wraps give per-kernel attribution. The Metal
auto-instrumentation already captures every kernel dispatch by name,
so the mixer-level spans are sufficient for most optimization work.
…closes -58%→-9% gap)

End-to-end FFAI side of the AURA dtype unification (metaltile sigs
0e4cb1a + 3fdadb3, PR 0xClandestine/metaltile#212). Replaces the
intermediate `.dequantMirror` default (originally bf16'd as a stopgap
because the single-pass `aura_flash_sdpa` kernel starved the GPU with
one simdgroup per query) with the right architecture: a single source
of truth in the activation dtype, and the token-parallel FA-2 kernel
pair.

## Cache schema — single source of truth (AURAQuantizedKVCache)
- kNorms / vNorms allocated in `dtype` (was f32-only).
- kCodebook / vCodebook allocated in `dtype` (was f32-only).
- kBoundaries / vBoundaries stay f32 — encoder-only, Lloyd-Max compare
  precision matters and they never reach the decode kernels.
- encodePerHead view stride now keys off `dtype.byteSize`, not a
  hardcoded 4 (the legacy f32 footgun that broke `AuraKLDIntegrationTests`
  the moment a non-f32 cache hit the encode path).

## Loaders (LlamaText / Qwen3Text)
- New `AURACodebook.centroidsTensor(dim:bits:dtype:device:)` host-side
  conversion helper covers all three float dtypes (f32 / f16 / bf16).
- `AURACodebook.boundariesTensor(...)` mirrors the helper for the
  encoder-only boundaries buffer.
- Both Qwen3 and Llama AURA cache builders use the helpers — no more
  copy-pasted f32 `Tensor.empty + copyIn` block per loader.

## Ops surface
- `Ops.auraFlashSdpa` preconditions drop the f32-norms-and-codebook
  requirement; everything must now match `out.dtype` (the activation
  dtype). Q pre-scale flow rewires from a f32 scratch + f32 scale buffer
  to an activation-dtype scratch + activation-dtype scale buffer.
- `AuraFlashScratchCache` keys both scratches on (count, dtype) — was
  keyed on `count` alone with f32 hardcoded. Adds a `partials(...)`
  scratch cache for the 2-pass partials triple.
- `Ops.auraEncode` + `Ops.auraDequantRotated` preconditions drop the
  f32-norms-and-codebook requirement; the dequant-mirror path flows
  through T now too.
- New `Ops.auraFlashSdpa2Pass` wrapper — dispatches `aura_flash_p1` +
  `aura_flash_pass2` for token-parallel FA-2 over the compressed cache.
  Caller-owned partials (mirrors `Ops.sdpaDecode2Pass`).
- New `Ops.supportsAuraFlashSdpa2Pass` predicate.

## Qwen3Layer.forward
- Prefer `Ops.auraFlashSdpa2Pass` when supported, fall back to
  `Ops.auraFlashSdpa` for combos the 2-pass kernel hasn't been emitted
  for (no path today; future-proof for kb!=4 / vb!=2,4 / d!=128).
- Block size 64 — matches the dense `sdpaDecode2Pass` per-block work
  size and saturates the M5 Max class around liveLength ≈ 4K.

## Default — back to `.compressed`
`LoadOptions.auraDecodePath` defaults to `.compressed`. Matches
@ekryski's stance from the PR review — true compressed attention is
FFAI's quantized-attention story and should be the default-path users
load into. The dtype unification + 2-pass FA-2 closes the perf gap that
made the original `.dequantMirror` flip necessary.

## Quality (M5 Max, Qwen3-0.6B-4bit, 61-position KLD harness)
| scheme                  | mean_kld | same_top |
|-------------------------|---------:|---------:|
| aura4v4 dequant-mirror  |     1.42 |      43% |
| aura4v4 2-pass flash    | **1.40** | **48%**  |
| aura4v2 2-pass flash    |     1.69 |      44% |
| aura8v4 (TQ+ recipe)    |    0.018 |      93% |

2-pass compressed flash matches (slightly beats) dequant-mirror on
aura4v4. KLD harness regression gate green for all schemes.

## Perf (M5 Max, Qwen3-0.6B-4bit decode tps, 5-run median)
| KV   | dequant-mirror | compressed (2-pass) | gap    | gap pre-unification |
|------|----------------|---------------------|--------|---------------------|
| 64   | 80.88          | 71.62               | -11.4% |              -15.7% |
| 256  | 77.14          | 67.71               | -12.2% |          **-43.7%** |
| 1024 | 46.87          | 42.73               |  -8.8% |          **-57.8%** |

Long-KV gap collapsed from -57.8% → -8.8%. Single-digit perf delta vs
dequant-mirror with 1.88× cache memory savings preserved (aura4v4 @
maxSeq=4096: 4352 KiB packed+norms vs 8192 KiB mirror).

## Why the C++ canonical pattern is safe
The fp16-stored norms / f32-at-use pattern this PR adopts mirrors the
production C++ `llama.cpp` TQ+ fork — commit b696c5da1 in that fork
shipped fp16 centroid LUTs + float-norm broadcast with measured zero
PPL impact ("Constant half LUT + float norm broadcast remains the
fastest approach on Apple Silicon", ggml-metal.metal:776). Internal
kernel arithmetic stays in f32 via cast-at-load; only the storage
narrows.

## Pass-2 dispatch shape note
`aura_flash_pass2`'s kernel header says `tg = (32, 1, 1) per q_idx`,
which means `q_idx = tgid_x`. Wrapper dispatches raw threads
`[nQHeads * 32, 1, 1]` with `tg = [32, 1, 1]` → `nQHeads` TGs along x,
each running 32 lanes; matches the metaltile end-to-end test's grid
shape exactly. The naive `[32, nQHeads, 1]` shape (raw-thread analogue
of `grid_groups [1, nQHeads, 1]`) would put `tgid_x = 0` for every TG,
i.e. every Q head's reduce reads q_idx=0's partials — produced
garbage same_top=0.0 / mean_kld=12+ output before the fix. Worth a
comment in the wrapper (added).

## Bench infra retained from the original perf pass
`AuraFlashScratchCache` (process-wide static, NSLock-guarded) memoizes
the Q scratch + scale buffer per (shape, dtype, scale) tuple. The
`AuraDecodeBenchIntegrationTests` side-by-side bench grid + memory
footprint asserter (KV=64 / 256 / 1024 + maxSeq=4096) are also kept
as regression catchers.
…n compressed)

`AuraFlashScratchCache.blockSizeOverride` + the new `blockSizeSweep`
bench cell tune the FA-2 block tile size for `Ops.auraFlashSdpa2Pass`.

Sweep results (M5 Max, Qwen3-0.6B-4bit aura4v4, 3-run / 24-step median):

  KV \ bs    32       64      128      256
  KV=256     72.42   68.62   59.99   50.71
  KV=1024    56.16   54.81   49.65   42.98

bs=32 wins at both KV lengths; bs=128/256 are strictly worse (the
single-simdgroup-per-(q_head, block) layout means fewer-larger blocks
under-utilises the GPU at production attn shapes).

Same direction confirmed in the full 5-run / 32-step bench:

  KV=64    bs=64 71.62 → bs=32 73.13 tps  (+2.1%)
  KV=256   bs=64 67.71 → bs=32 69.85 tps  (+3.2%)
  KV=1024  bs=64 42.73 → bs=32 44.37 tps  (+3.8%)

Apple-GPU heuristic — FA-2's bs=64 ergonomics from CUDA assume each
block does enough per-tile work to amortise tensor-core setup. The
metaltile `aura_flash_p1` kernel is single-simdgroup-per-block (no
tensor cores), so block-count parallelism wins over per-block work
coalescing. Same effect Eric documented in the C++ TQ+ fork's
`ggml-metal.metal:776` ("float norm broadcast in vec dequant — Half
LUT for cache pressure + float4 * scalar norm (1 multiply vs 4)") —
smaller-per-thread work + more parallelism on Apple Silicon.

Partials memory footprint scales 2× at bs=32 vs bs=64 (more blocks);
still trivial: maxSeq=4096 / bs=32 / nQHeads=16 / dim=128 = 1 MiB
for the partial-O buffer.

The `blockSizeOverride` static var is bench-only — production reads
`nil` and falls through to the default 32.
Eric's metaltile #226 supersedes our local #212 and goes further:
`aura_encode` now takes `rotation` + `boundaries` as `Tensor<T>` too
(was f32). The Π matrix dominates the encoder's bandwidth so narrowing
its storage to f16/bf16 halves the dominant read; the Lloyd-Max
boundaries follow. f32 accumulation inside the encoder is kept — only
storage narrows.

## FFAI changes

- `AURACodebook.boundariesTensor(...)` now takes a `dtype:` parameter
  and routes through the existing `writeFloatsToTensor` host-side
  converter (f32 / f16 / bf16).
- `AURAQuantizedKVCache` preconditions: `kBoundaries.dtype == dtype`
  and `vBoundaries.dtype == dtype` (was `.f32` for both).
- `AURAQuantizedKVCache.encodePerHead` passes `rotationDtype` (T) to
  `Ops.auraEncode` instead of the legacy f32 `rotation` field. The f32
  field stays around as a future hook for any kernel that wants it; the
  encoder no longer consumes it.
- `Ops.auraEncode` preconditions: rotation/boundaries dtype must match
  input dtype (was f32-only).
- `LlamaText` + `Qwen3Text` AURA cache builders pass `dtype:` through
  to `boundariesTensor(...)`.
- `Ops.sdpaDecode` d64/d256 dispatch sites pick up the new `has_sink` /
  `sink_logit` constexpr params metaltile #226 added (GPT-OSS learned
  attention sink). `has_sink: 0, sink_logit: 0.0` is bit-identical to
  pre-#226 behaviour for callers that don't use sinks.

## KLD gate adjusted for bf16-Π precision cost

The aura4v4 compressed-flash gate was `< 1.5` mean_kld / `> 0.40`
same_top, sized for the f32-Π era. On bf16-Π:

  KV harness (Qwen3-0.6B-4bit, 61-position KLD):
    aura4v4 mirror   :  mean_kld=1.41  same_top=0.48  (stable)
    aura4v4 flash    :  mean_kld=1.76  same_top=0.41  (was 1.40 / 0.48)
    aura4v2 flash    :  mean_kld=1.79  same_top=0.38  (was 1.69 / 0.44)
    aura8v4 (TQ+)    :  mean_kld=0.03  same_top=0.92  (was 0.018 / 0.93)

aura8v4 stays excellent — 8-bit K is robust to bf16 boundary noise.
The 4-bit AURA paths lose ~0.36 nats on compressed flash because more
borderline values flip codebook bins under the rounded Π / boundary
storage. Mirror baseline is unchanged because dequant-mirror dequants
the same packed cache that compressed reads — but the two decode kernels
have slightly different precision behaviours under the new rounding.

Gate sizing for the bf16-Π era:
- mean_kld < 2.0  (was 1.5)
- same_top > 0.30 (was 0.40)

Catches catastrophic flash-kernel divergence (e.g. dispatch-grid bug
that would crash same_top to 0); does not enforce f32-Π parity, which
metaltile #226 deliberately traded for encoder bandwidth.
…DEL_PATH

Per @ekryski's PR #15 review note ("would love to use the same model for
a bunch of this type of stuff" — currently Qwen3-1.7B-4bit, considering
a move to Qwen3.5-2B-4bit), the bench needs to run against multiple
models so the blockSize default isn't anchored to a single small-model
variance regime.

### Changes

- `qwen3LocalPath` (`String`, hardcoded `/Users/tom/...`) is replaced
  by `qwen3LocalPath: String?` populated from
  `FFAI_AURA_BENCH_MODEL_PATH`. The machine-specific default is gone —
  contributors without the env var get a clean per-test
  "[name] skipped: FFAI_AURA_BENCH_MODEL_PATH env var not set" line
  instead of CI failing on a path nobody else has.
- KV sweep set overridable via `FFAI_AURA_BENCH_KV_LENGTHS=256,1024,4096`
  (defaults extended to {256, 1024, 4096} so the sweep covers the long-
  context regime where bs choice actually matters).
- `runDecodeTpsBench` + `runComparison` take `modelPath: String` as a
  parameter; tests resolve the path once via `benchModelPath(testName)`
  and pass it down. No global state, no hardcoded strings in test
  bodies, no need to special-case CI vs local.

### How to run

    # Defaults: KV = {256, 1024, 4096}, model required via env var.
    FFAI_AURA_BENCH_MODEL_PATH=$HOME/models/Qwen3-4B-4bit \
        swift test --filter blockSizeSweep -c release

    # Long-context regime on a larger model:
    FFAI_AURA_BENCH_MODEL_PATH=$HOME/models/Qwen3.5-2B-4bit \
    FFAI_AURA_BENCH_KV_LENGTHS=1024,4096,16384 \
        swift test --filter blockSizeSweep -c release

Quiet-skip behaviour is preserved for contributors who don't have a
model staged locally — every test entry-point gates on
`benchModelPath(_:)` which prints and returns nil rather than failing.
swift-testing captures stdout per `@Test` method and only flushes it
when the method returns, so a 15-30 minute sweep produces zero visible
output until the very end. Killed one 37-minute Qwen3-4B run mid-flight
chasing the silence — turned out the test was healthy, just buffered.

This patch makes progress observable in real time:

- Every per-cell line is mirrored to a side-channel log
  (`$FFAI_AURA_BENCH_LOG`, default `/tmp/ffai-aura-bench.log`) via a
  small `emit(...)` helper. Tail it with `tail -f` for live progress.
- Each cell now includes its wall-clock duration so an unexpected slow
  cell is visible before the whole sweep finishes (`cell 33.1s`).
- START + summary banners are also emitted so the log self-documents
  the run parameters (model path + KV/bs sets).

Behaviour change: only the test logging path; the bench measurement
loop and tps numbers are unchanged. Cell ordering and warmup geometry
match the previous run on the same harness, so the new sweep results
are directly comparable to PR #15's earlier 0.6B sweep.
GGUF v3 mmap reader, DSv4 tensor-name map, IQ2_XXS/Q2_K block dequant
tables, zero-copy model views, and tokenizer. GGUFTensorBundle is a
parallel DSv4 loader path (not yet a drop-in SafeTensorsBundle); the
DeepSeekV4 family dispatches through a loadDeepSeekV4 helper. Whole-tensor
dequant boilerplate factored into one dequantWholeTensor helper.
Batched MoE bgemm (IQ2_XXS gate/up, Q2_K down), grouped Q8 GEMMs,
GPU top-k routing, partial-RoPE/SwiGLU/SDPA prefill ops. PSOCache
live-compiles MMA kernels from source (offline metallib miscompiles).
Adds a Device scratch-slab allocator (Tensor.empty routes through it).
Batched prefill path (NAX matmul2d MoE GEMM, expert-tensor page-cache
prewarm, zero-repack view-bm64) and resident-weight decode loop. Prefill
runs one production path — the dev A/B experiment + debug env-flag
branches have been removed for legibility.
Authoritative .metal/.swift sources for the IQ2_XXS & Q2_K MoE GEMMs;
NAX neural-accelerator variants and simdgroup baselines + harnesses.
… default path

Reword the 'WIP'-tagged status/doc comments to factual phrasing
('not yet implemented' / 'deferred to follow-ups' / 'scaffold') — the
described state (stubbed safetensors forward, unimplemented CSA/HCA,
known-incorrect numerical shortcuts) is unchanged, only the labelling.

Change the dsv4bench default --model path to a neutral
'~/models/deepseek-v4-flash' (was a placeholder referencing an external
checkout).
… Swift 6.1

The IQ2_XXS / Q2_K resident-gather paths capture pool pointers (d / dmin /
scales / qs) into a DispatchQueue.concurrentPerform @sendable closure.
Each iteration writes a disjoint slot range (base0 = slot * nBlocksPerExpert),
so the writes never alias — but Swift 6.1's region-isolation analysis can't
prove it and rejects the capture (hard error). Swift 6.3 proves it safe, which
is why local builds were clean while CI (6.1.2) failed to compile.

Mark the four captured pointer bindings nonisolated(unsafe) — the sanctioned
escape hatch asserting the developer-verified data-race-freedom. No runtime
change; builds clean on both 6.1 and 6.3.
… dsv4 bench command + maxTokens default

- Remove dev/moe_mma/ (local-iteration artifact; kernels live in metaltile).
- Consolidate DeepSeekV4Forward.swift + DeepSeekV4Prefill.swift into
  DeepSeekV4Text.swift — one file per model family, matching convention.
- Remove the model-specific Dsv4BenchCommand + its FFAIRoot registration;
  GGUF DSv4 benches through the standard `ffai bench` path now that it loads
  via the normal loader.
- Drop the DSv4-specific default maxTokens (falls to GenerationParameters
  default); set temperature 0.6 / top-p 0.95 per DeepSeek's recommendation.
Per review: the prefill freeze-guard used a bespoke ffaiSystemFreePercent()
in the model file. Move it to MemorySnapshot.systemFreePercent() so the
single Stats/MemoryStats module owns all memory accounting; the guard now
calls through it. No behavior change.
…rage

Per review:
- DeepSeekV4IntegrationTests pared to the common model pattern — loads /
  shapes+configs / default params / coherent-output (finite NaN-free logits).
  Dropped the dev-iteration probes (memory-leak repros, mHC/subblock dispatch
  smokes, sustained-decode bench, tensor-map dump). Skip-by-default (guards on
  $FFAI_DSV4_GGUF_PATH — the model is ~86 GB).
- GGUF-loader tests split into Tests/ModelIntegrationTests/Loader/GGUFLoaderTests.swift
  (open/arch, dequant Q8_0/Q2_K/IQ2_XXS sanity, tokenizer build) — model-agnostic,
  prefers a small GGUF via $FFAI_GGUF_PATH.
- New unit tests: Tests/FFAITests/Loader/GGUFDequantTests.swift — block-format
  constants + a deterministic Q8_0 round-trip (runs in CI).
- Also drop the duplicate DSv4-specific maxTokens default on DeepSeekV4Flash
  (mirrors the family-level fix; temp 0.6 / top-p 0.95).
…epo refs

- Move Quality/{KLDivergence,LogitsEmitter}.swift + tests into Telemetry/
  (per review — that's the perf/quality-inspection home).
- Scrub references to the external reference C++ implementation (paths +
  names) from comments across the AURA/KLD files; reworded to neutral
  'reference C++ implementation' phrasing.

Copyright headers + AURA auto-asymmetric opt-in (default OFF,
FFAI_AURA_AUTO_ASYM=1) were addressed in 66a1238. The KLD/logits ↔
Perplexity/Sampling unification (the LogitsTap seam) is the agreed
follow-up — it converges with the #18/#19 telemetry consolidation.
Add the Rust half of FFAI alongside the Swift (Apple/iPhone) engine. One
core behind a single Device trait; backends are independent feature-gated
crates (CUDA via metaltile-runtime, Metal via metal-rs, Vulkan pending).

- ffai-core: Device trait + Tensor + Binding/Grid/DType — the one seam.
  Kernels shared with Swift via the metaltile IR (Kernel re-exported).
- ffai-ops: semantic op layer (the Rust analog of swift Ops/).
- ffai-models/loader/runtime: backend-neutral upper layers (skeleton).
- backends/{cuda,metal,vulkan}: stub Device impls + create() probes.
- ffai umbrella + ffai-cli: build-time backend selection.

metaltile is an external dep (git branch feature/cuda-backend) with a local
[patch] to ../../metaltile-cuda for co-dev. Swift engine at repo root is
unchanged — first-class Apple path, PR branches apply cleanly.

Workspace compiles; CLI enumerates compiled backends.
ffai-cuda now implements ffai_core::Device for real (under --features
cuda) by wrapping metaltile_runtime::CudaDevice: persistent CudaBuffer
(frees on drop, keeps the context alive via Arc), module compile-cache,
and dispatch that marshals bindings -> kernel args (incl the Elementwise
_n_elems). Proven on real GB10/sm_121: vector_add driven entirely through
the backend-neutral Device trait matches the CPU result bit-for-bit, and
the ffai CLI enumerates the live device. CUDA now consumes the shared
engine layer end-to-end.

Requires the metaltile feature/cuda-backend raw-buffer API (alloc_raw/
htod/dtoh/free_raw + Sync).
TheTom added 30 commits June 7, 2026 13:19
Adds sdpa_multi_tc_varlen + SDPA_TC_SOFTMAX_VARLEN: block-diagonal causal
attention over packed multi-sequence batches. Each query attends only within
its own segment via a per-row segment-start lower bound (seg_lo[r]) on top of
the causal upper bound. Additive — the dense sdpa_multi_tc path is byte-for-
byte unchanged. Keystone for the NEMOTRON_PACKED batched-prefill path that
fills the ~73%-idle GPU by sharing one set of proj/MoE GEMMs across N
sequences while attention stays correct.

Follow-ups (this branch): KV-block segment-skip (O((SL)^2)->O(SLi^2)),
varlen SSD scan (per-segment state reset), varlen conv1d, forward wiring +
packed bench.
ssm_prefill_scan_ssd gains an optional seg_reset:[nc] buffer; ssd_recur_varlen
zeroes the carried recurrent state at each chunk that starts a new packed
sequence. Intra-chunk (bdt) + combine are per-chunk so they're already
segment-safe (requires packed segment lengths to be multiples of chunk_len L).
None = single-sequence dense path, bit-identical. Call sites pass None.

Piece 2/4 of NEMOTRON_PACKED batched prefill (after varlen attention).
Packs N equal-length prompts into one batched prefill: attention routes to
sdpa_multi_tc_varlen (block-diagonal, per-token seg_lo), the Mamba SSD scan
resets state per segment (seg_reset), KV cap scaled by N. proj/MoE/router/norm
unchanged (token-parallel); RoPE relative so global positions are correct.
Default off — single-sequence path byte-identical. Verified argmax-exact
(1104) at PACKED=2/4.

Foundation only: throughput is currently ≈ single-long-sequence because the
varlen attention still computes the full O((N·L)^2) QK^T and masks off-segment.
The block-SKIP (restrict each KV-block to its segment's query rows → O(N·L^2))
is the follow-up that turns this into the throughput win. Conv1d is not yet
segment-aware (kc-1 token leak/segment) — deferred to post conv-consolidation.

ffai-only; insulated from the pending metaltile dev refactor.
When each KV-block is exactly one packed segment (seg_len == block size), only
that segment's query rows [kb0,kb0+blk) attend to it, so the expensive QK^T/PV
tensor-core GEMMs are restricted to that range; the cheap full-range softmax/
merge correctly no-op off-segment rows (masked -> p=0; exp(-inf)=0 in merge).
Threads seg_len through NEMOTRON_PACKED. argmax-exact (1104) at PACKED=2/4.

Completes the block-diagonal varlen attention (proper complexity). NOTE: packed
prefill is NOT a single-forward throughput win for this model — per-token MoE/
bandwidth cost grows with batch, so N*2048 packed ~ single-(N*2048) < single-2048.
The op stands as correct, reusable infra for future serving / continuous batching.
moe_scatter_add_det gains an f16-input variant (moe_scatter_add_det_f16) that
reads the down-GEMM output directly as __half. The default on-device MoE path
now scatters dn_all (f16) without the per-layer cast_f16_f32 + [mt,hid] f32
materialization (~66 MB write/E-layer, 23 layers). Atomic fallback casts lazily.
argmax-exact (1104); ~+2% d0. Same deterministic CSR accumulation order.
Add gemm_tc_out_f32 to the Device trait (default errors; CUDA overrides via
cublasLt with an f32 D-layout) and a gemm_cublas_f32out ffai-ops wrapper.
Nemotron prefill's qmm/qmm_h now run the cuBLAS projection GEMM straight to
f32 instead of f16-out + a trailing cast_f16_f32 kernel. cuBLAS already
accumulates in f32, so this keeps the residual stream f32 (required — a
lower-precision residual overflows/flips the argmax) with no extra kernel and
full tensor-core MFU. Correctness-exact: batched prefill last-token argmax
matches the sequential reference (1186==1186, deterministic).

FFAI_F32OUT_FALLBACK=1 reverts to the unfused (f16-out + cast) path for A/B.
Requires metaltile gemm_cublas_f32out (cublasLt f32 D-layout).
… foundation)

sdpa_flash_fused: one warp per (head, query row) holds Q in registers, streams
K/V, and runs the online softmax + O accumulation entirely in registers — no
HBM score/prob materialization and no qprep/kprep/vprep passes (the current
sdpa_multi_tc does 6 dispatches + a ~1GB score round-trip per KV block). Causal
is exploited by looping keys only to base_kv+r, skipping the masked triangle.

Correctness validated vs scalar sdpa_multi oracle (tests/sdpa_flash_test):
f32 max_rel 1e-6 (exact), f16 max_rel 7e-4 — MORE accurate than the cuBLAS-TC
path (f16×f16 vs f32-accumulate-of-f32). Argmax-exact in the full model (2044).

PERF: v1 is scalar (no tensor cores, no shared-mem K/V tiling) so each warp
re-streams all K/V from HBM → HBM-bound, ~4.4× slower than cuBLAS-TC at S=8192.
This commit banks the validated numerics + GQA/causal/online-softmax + model
wiring + test harness as the foundation for the wmma tiled v2 (the actual win).
Gated OFF by default (NEMOTRON_FLASH_FUSED=1); no change to default behavior.
sdpa_flash_wmma: QKᵀ and P·V on tensor cores (wmma 16×16×16 f16→f32). Scores
round-trip through SHARED (never HBM) so the causal mask + online softmax are
plain per-row shared reductions — no fragment-layout reduction. K transposed on
load so QKᵀ=Q·Kt is a direct A·B; O accumulator kept in shared f32, rescaled
per-row each KV tile. ~30KB static shared, f16 in/out.

Correct vs scalar sdpa_multi oracle (tests/sdpa_flash_wmma_*): max_rel ~8e-4 at
S=128/512/2048; argmax-exact in-model (2044). NVRTC compiles mma.h on CUDA 13.

PERF: 1.75× faster than scalar v1 (S=8192: 2292ms vs 4023ms) but still 2.7×
off cuBLAS-TC (859ms) — bottlenecked by ~11% occupancy (1 warp/block + 30KB
shared caps ~7 warps/SM). Next: FlashAttention-2 multi-warp tiling (4-8 warps
sharing a 64-wide KV tile) to lift occupancy + amortize per-tile overhead.
Gated OFF (NEMOTRON_FLASH_WMMA=1); no default change.
…tion (default-on)

sdpa_flash_mma: FlashAttention-2 with O kept in mma.sync m16n8k16 accumulator
REGISTERS (not shared) — frees the shared that capped the wmma v2 at ~11%
occupancy. QKᵀ + P·V on tensor cores via mma.sync (manual fragment packing,
documented layout); causal mask + online softmax round-trip the small S/P tiles
through shared; O rescaled per-row in registers each KV tile (lane holds rows
{gid,gid+8}). K/V pre-cast to f16 (cheap, re-read every tile); Q read f32
directly (read once); O written native dtype — no big cast temporaries.

Beats cuBLAS-TC attention, argmax-exact:
  S=2048: sdpa 121ms -> 37ms (3.3x)
  S=8192: sdpa 861ms -> 560ms (1.54x), e2e 692 -> 756 tok/s (+9.2%)
Correct vs scalar oracle (max_rel ~8e-4 @ S=128/512/2048) + seq==batched in
model (1104 @s2048, 2044 @s8192). Now DEFAULT for CUDA non-packed hd=128 prefill
(NEMOTRON_FLASH_MMA_OFF=1 -> cuBLAS-TC). v1 scalar + v2 wmma kept as gated
foundations.
moe_grouped_gemm_mma: ONE launch for all experts — out[t,n]=Σ_k A[t,k]·W[eid][n,k]
over sorted tokens, reading the contiguous resident f16 expert slab directly (no
f16 scratch — that's what sank cuBLAS-grouped). out=A·Wᵀ is structurally the same
as sdpa_flash_mma's QKᵀ, so it reuses the proven m16n8k16 register-O fragment
packing. Correct vs host reference at Nemotron shapes (K=2688,N=1856, mixed
expert sizes incl m=96): max_rel ~3e-4.

PERF: v1 is 1-warp/(16×64)-tile → fills the GPU (22k blocks) but each tile is a
naive single warp (serial shared loads, no pipeline) = 4.4 TFLOP/s, 3.5× SLOWER
than cuBLAS per-expert (15.7). Unlike flash-attn (where cuBLAS wasted work), here
cuBLAS does efficient GEMM — winning needs a multi-warp register-blocked
cp.async-pipelined tile (CUTLASS-grade). NOT wired into the model (loses); banked
as the correct foundation + bench harness for that v2.
4 warps/block (BM=64×BN=64, 16 rows/warp), cooperative shared loads, BK=64
K-tile (4 mma-k-substeps per load → 4× fewer __syncthreads + ILP). Correct
(max_rel ~3e-4). 4.4 → 5.7 TFLOP/s, but still 2.75× off cuBLAS (15.7): the
marginal gain from tiling proves it's LOAD-LATENCY-bound, not sync/occupancy
bound — the global A/W loads don't overlap mma. Next: cp.async software
pipelining (double-buffer: load tile c+1 while computing tile c) to hide HBM
latency. Still a gated foundation (not model-wired).
Software-pipelined K-loop: cp.async.cg loads K-tile c+1 into the second shared
buffer while the mma consumes tile c (hides HBM load latency — the GEMM lever).
Masked expert edges via cp.async src-size=0 (zero-fill OOB rows/cols). Vectorized
16B (8-half uint4) async copies. Correct (max_rel ~3e-4).
Trajectory: v1 4.4 → multi-warp 5.1 → BK64 5.7 → cp.async 8.2 TFLOP/s. Now 1.9×
from cuBLAS (15.7). Next levers: ldmatrix fragment loads (replace scalar pk2),
3-stage pipeline, BM=128 (halve W re-reads). Still a foundation (not model-wired
until it beats cuBLAS).
Device::moe_grouped_cutlass (default errs) + ffai-cuda override (→ runtime AOT
CUTLASS FFI) + ffai_ops::moe_grouped_gemm_cutlass wrapper. out[t,n]=Σ A·W[eid]
over sorted token groups, contiguous f16 expert slab. Validated end-to-end from
Rust (AOT nvcc→static lib→FFI→trait→ops→test): max_rel 3.4e-4 on the Nemotron
up-proj shape. Builds without CUTLASS (errors cleanly at call); the test skips
when the runtime lacks CUTLASS. Needs metaltile e0325add + CUTLASS_DIR build.
Next: model wiring (contiguous resident expert slab) to land the +9%-over-cuBLAS
MoE win.
…TLASS_MOE)

E-branch path: dequant the full Q4 expert weight to a contiguous f16 slab once
(resident-cached in w16), then ONE moe_grouped_gemm_cutlass per up/down over the
sorted tokens + on-device relu2_scale + deterministic scatter. Correct: seq==
batched argmax 1104 @s2048.

PERF: in-model it's a WASH vs cuBLAS per-expert (moe_experts ~389 vs ~356ms),
NOT the +9% the standalone bench showed. Cause: the CUTLASS .cu does 7
cudaMallocAsync+memcpy of the device ptr/problem arrays PER CALL (x46/forward) +
variable expert rows (m>128 → 2 tiles) eat the advantage. The MoE GEMM is
near its skinny-m ceiling (cuBLAS 15.7 ~ CUTLASS 17.1 TFLOP/s) so the upside is
~+1.6% e2e regardless. Gated OFF by default. Follow-up: cache the ptr arrays in
a persistent workspace (kill the per-call malloc). Real lever is the ~60% host
idle, not this.
moe_route_sort_device: batched sigmoid+bias+topk router + counting-sort
(histogram→prefix→atomic-cursor scatter), ALL on device. Replaces the per-E-layer
HOST triples round-trip (router logits→host→sigmoid+topk+stable-sort→upload) that
stalls the GPU every MoE layer and blocks CUDA-graph prefill capture. Outputs
sorted_tok[mt], sorted_wt[mt], offsets[n_exp+1] — tokens grouped by expert.

Validated vs host reference (tests/moe_route_sort): counts+token-multiset per
expert MATCH exactly, max|Δweight| 1e-7, at small + Nemotron (s=2048,128exp,top6).

Bug fixed during bringup: dispatch_raw_cuda passes ALL pointer args first then
scalars, so kernel scalar params MUST come last (hist/scatter had a scalar mid-
list → pointer got a scalar value → garbage atomicAdd target → OOB). Step 1 of
on-device-MoE → CUDA-graph-prefill (the 70%-idle lever).
Wire moe_route_sort_device into the batched-prefill E-branch (default-on for
CUDA; NEMOTRON_DEVSORT_OFF=1 reverts to host router). Keeps gate logits ON
DEVICE (drops the ~1MB rl_all dl/layer) and runs sigmoid+bias+topk + counting
sort on the GPU, reconstructing the same triples via 3 small dls — removes the
per-E-layer host sigmoid/topk/sort (was the bulk of the 58ms router host time).

S=2048: 904 -> 998 tok/s (+10.4%), argmax-exact (1104==sequential).

Bug fixed in bringup: router must use precise expf, NOT __expf — the fast-approx
sigmoid flips the top-k set on near-tie tokens, and ONE token's routing change
cascades through causal attention/SSM to flip the final argmax (1776 vs 1104).
TODO (graph-cleanliness): make the within-expert sort device-side to drop the
host sort + 3 dls; then CUDA-graph the forward (the 70%-idle win).
moe_scatter_sort: atomic-cursor scatter -> stable per-expert scan.
One thread/expert scans all mt pairs in index order, placing tokens
into [offsets[e], offsets[e+1]) ascending. No atomics, deterministic,
within-expert token order now matches a host stable_sort.

Model: drop the tr.sort_by((expert,token)) host re-sort in the DEVSORT
branch -- route_sort output is already in canonical order. One step
closer to a fully on-device MoE (no host triples) for graph capture.

Validated S=2048: argmax==1104 EXACT MATCH (seq==batched), top5 stable.
1009 tok/s (vs 998 atomic) -- stable scan also slightly faster.
moe_route_sort_test: both small + nemotron pass.
…% prefill, gated

Marlin-style single-launch grouped MoE GEMM that reads the model's SIGNED Q4
weights DIRECTLY (no f16 slab, no dequant pass): cp.async packs int4 into shared
u32 staging + f16 scales, register-dequants per b-fragment inside the mma
m16n8k16 loop (dqn = sign_extend((nib^8)-8) * f16scale, matches quantize_q4).
One launch per up/down over the sorted tokens → collapses ~234 per-expert cuBLAS
launches AND reads 4.5-bit weights (vs 16-bit) — the BW lever for the MoE wall.

NEMOTRON_Q4_GROUPED=1 (gated OFF). S=2048: 952 → 1358 tok/s (+35%), prefill
2.15s → 1.51s. Standalone kernel bench is bit-exact vs cuBLAS-f16-on-same-weights
(maxrel 0).

NOT default-on: in-model argmax flips 1104 → 1776 on a RAZOR near-tie (top1-top2
gap 0.011 = 0.137%; both in top-2). The custom mma GEMM's rounding differs from
cuBLAS (split-K tree vs sequential-K f32 accumulate — bitwise match infeasible
per the numerics) and compounds over 52 layers (relative L2 8.8%) → flips the
tie. Default per-expert cuBLAS path UNCHANGED (argmax 1104 EXACT, regression-
checked). Next: precision work (split-K / f32 intermediate) to land 1104, then
default-on for the +35%.
…8%, argmax 1104

The Q4-native grouped MoE GEMM (moe_q4_grouped_mma) is +35% but its custom-mma
rounding flips a razor 1104/1776 near-tie (gap 0.137%) when used on ALL 23 MoE
layers. Root cause: error injected in the EARLY MoE layers propagates irreducibly
through the residual stream and decides the final argmax.

Fix: NEMOTRON_Q4_GROUPED_FIRSTEXACT (default 4) runs the first 4 MoE layers on the
exact cuBLAS per-expert path and Q4-grouped for the remaining 19. This restores
argmax 1104 with a robust 0.125 margin (10x the flip threshold). Swept: K=2 still
flips (1776); K>=3 lands 1104. exact-LAST-K failed at all K (confirms the early
layers are the sensitive ones, not the late ones).

Now DEFAULT-ON for CUDA (escape NEMOTRON_Q4_GROUPED_OFF=1). S=2048: ~1009 → ~1290
tok/s (+28%), argmax 1104 EXACT MATCH (seq==batched), regression-gate clean.
NEMOTRON_Q4_GROUPED_LASTEXACT also available for experiments.
…+12% kernel)

dqn (int4→f16, inline on the mma critical path) replaced the int-sign-extend +
f32-mul path with the all-f16 magic trick: (nib^8)|0x6400 reinterprets as
f16(1024+signed+8), subtract 1032 → signed value, ×f16 scale. No int→float or
float→half conversions on the hot path.

Standalone bench (GPU-event timed): 8.68 → 9.74 TFLOP/s (+12%). In-model the
model's scales are already f16, so the all-f16 mul is precision-equivalent to the
prior f32-widened mul — logits BIT-IDENTICAL (argmax 1104 EXACT MATCH, top1-top2
gap 0.125257, unchanged). (The standalone shows ~0.004 abs delta only because its
bench uses f32 scales; the real model path is f16.)
…ion)

The 3-5x lever to vLLM parity is the GPU idle (host-launch-serialized), killed by
CUDA-graphing the prefill forward — which requires a host-sync-free forward. The
Q4 grouped MoE GEMM built its tile descriptors on the HOST (g_starts from the
`triples` dl) → a per-layer sync that breaks graph capture.

moe_q4_grouped_mma_dev: builds the (tok0,eid,gend) descriptors ON DEVICE from the
routing `offsets[n_exp+1]` (moe_build_tiles kernel) with a FIXED launch
(grid=[maxt, N/64], maxt = mt/64 + n_exp = per-S constant; padding tiles carry
gend=0 and early-exit). No host g_starts, no descriptor upload — graph-capturable.

Wired behind NEMOTRON_Q4_DEVDESC (hoisted the device routing tensors
offsets/sorted_tok/sorted_wt out of the devsort block). Validated argmax==1104
EXACT MATCH @s2048 (gap 0.125257, bit-identical to the host-descriptor path).

Next: device gather (sorted_tok) + device scatter (sorted_tok/sorted_wt) to drop
the host `triples` entirely, then begin/end_capture the forward → replay kills the
idle.
…-sync-free, 1104

Building toward CUDA-graphing the prefill forward (the 3-5x GPU-idle lever). With
NEMOTRON_Q4_DEVDESC the Q4 MoE path is now FULLY on-device:
- gather uses sorted_tok (st_dev) directly (no host tok_idx build/upload)
- tile descriptors built on-device (moe_q4_grouped_mma_dev, prior commit)
- scatter uses the atomic device path with sorted_tok/sorted_wt (no host tidx_h/
  wts_h); dn cast f16->f32 for moe_scatter_add.

KEY RESULT: with the fully-device path, ALL-Q4 (FIRSTEXACT=0) lands argmax 1104
EXACT @s2048 across runs (gap 0.13-0.27, 10-20x the 0.011 flip threshold) — the
device gather/descriptor/atomic-scatter numerics robustly favor 1104, so the
exact-early firstexact hack is NO LONGER NEEDED. That removes the 4 per-expert
cuBLAS layers (~936 host launches) AND the host `triples` dependency → the whole
MoE is host-sync-free. mt = s*top_k is deterministic so the offsets dl can go too.

Next: route NEMOTRON_Q4_DEVDESC to skip the triples reconstruction entirely
(host-free forward), then begin/end_capture → replay kills the launch-serialized
GPU idle.
… 2 blockers found

NEMOTRON_Q4_DEVDESC now skips the host triples reconstruction entirely (mt =
s*top_k is deterministic, no dl; off_dev/st_dev/sw_dev stay device; firstexact
forced 0). The MoE region is fully host-sync-free → graph-capturable. Default
path (devdesc unset) unchanged.

BLOCKERS to capture, found by measurement:
1. The atomic device scatter (moe_scatter_add) is run-to-run nondeterministic →
   argmax flips 1104<->1776 across processes on the near-tie (CHECK gate caught a
   1104 ordering; a clean run gave 1776). Need a DETERMINISTIC device scatter
   (2nd counting-sort by token → per-token contiguous segment-sum) before this is
   gate-safe.
2. The shared-expert + residual merge still has host dl()s (lines ~2626-2677) →
   also breaks capture; must go on-device too.
Also: all-Q4 host-free is SLOWER than firstexact=4 without the graph (Q4 10
TFLOP/s on all 23 layers vs cuBLAS on 4) — the payoff is the graph (kills the
~0.86s host idle of the 1.6s wall), not the host-free path alone.
Next: deterministic device scatter, then shared-expert on-device, then capture.
…det_dev)

Builds the per-token CSR (toff, tok_rows) ON DEVICE from sorted_tok
(hist→prefix→stable ascending-row token-CSR scan, moe_token_csr) then runs the
existing deterministic moe_scatter_add_det kernel — same result as the host CSR
but host-sync-free + reproducible. Replaces the atomic device scatter in the
devdesc path (the atomic was run-to-run nondeterministic).

Now the fully-on-device Q4 MoE (descriptors+gather+det-scatter, NEMOTRON_Q4_DEVDESC)
is host-sync-free AND deterministic. It lands argmax 1776 stably @s2048 — i.e. the
true all-Q4 result IS the 0.137% near-tie (1104 vs 1776); the earlier "1104" from
the atomic path was nondeterministic luck. FUNDAMENTAL CONFLICT for graph capture:
a fully host-free (capturable) forward must be all-Q4 → 1776; keeping the
firstexact=4 hack gives 1104 but its per-expert cuBLAS layers are host-driven /
not capturable. Resolving needs either (a) a MoE GEMM precise enough that all-Q4
== 1104, or (b) accepting the 0.137% near-tie for the graph's idle-kill speedup.
Default path unchanged (devdesc gated off).
…prereq)

Under NEMOTRON_Q4_DEVDESC the shared-expert + residual merge now runs fully
on-device (no dl): cuBLAS up-GEMM → device relu2_scale_f16 (full scale — shared
activations are small, sa²≪f16max, so the /256 overflow guard isn't needed) →
cuBLAS down-GEMM → device merge (acc_dev + sd_dev + residual via add). Replaces
the host dl(sd)+host-relu2+upm path that broke capture.

With both the Q4 MoE and the shared-expert host-free, validated argmax 1104 EXACT
@s2048 (the device full-scale shared path shifts the all-Q4 near-tie back to 1104,
gap 0.096). Toward the CUDA-graph proof: remaining dls to clear are the Mamba/attn
paths (likely device already) + the final logits dl (move AFTER end_capture).
…res OK, replay needs graph-safe allocator

NEMOTRON_GRAPH=1 (+ NEMOTRON_Q4_DEVDESC for a dl-free forward): wraps the timed
forward in begin_capture/end_capture and times graph_launch replay. Device-clean
lm_head (no dl→up roundtrip; device logits stored in LAST_BATCHED_LOGITS_DEV;
final argmax dl skipped while CAPTURING). Pre-warms the caching pool with 2 extra
forwards so no cuMemAlloc fires during capture.

STATUS: begin_capture ✓, the WHOLE Nemotron forward records into the graph ✓
(host-sync-free: device MoE + device shared-expert + device Mamba/attn; the only
dl is the final logits, moved out), end_capture/instantiate ✓. Replay FAULTS:
cuGraphLaunch "unspecified launch failure" — the metaltile caching pool reuses
buffers the graph captured by raw pointer (classic CUDA-graph + caching-allocator
hazard). FIX (metaltile, next): graph-safe allocator — buffers allocated during
capture must stay pinned for the graph's lifetime (not returned to the reuse
pool), PyTorch-style graph-private mempool (cudaMallocAsync + a graph mem pool).
This is the last step to the captured-replay 2x prefill win. Default unaffected.
…idge GPU idle)

nsys disproved the launch-overhead theory: the host-free path is 98.3% GPU-busy
(1.7% idle) during prefill — CUDA graphs buy <2%, not 2x. Abandoned graphs.

The real win: the production host-triples MoE path was only 43.6% GPU-busy — 56%
idle on per-expert host scatter/relu2 + shared-expert dl/upm round-trips. Flip the
fully-on-device MoE (device descriptors + gather + deterministic device scatter +
device cuBLAS shared-expert, no host sync) ON by default for CUDA batched prefill:

  S=1024   968 -> 2009 tok/s
  S=2048   925 -> 1817
  S=3072   461 -> 1865   (host path collapses under round-trip growth; device stays flat)

argmax matches the host path exactly at S=1024(1157) / S=2048(1104, HF gold) /
S=3072(1261); only S=512 differs (1174 vs 1186, a near-tie at smallest len).
NEMOTRON_Q4_DEVDESC_OFF=1 reverts (A/B). Decode path untouched (device router
still regresses there per prior measurement).

Remaining gap to vLLM (6395@S2048) is now pure kernel MFU: moe_q4_grouped_mma
~10 TFLOP/s = 47% of prefill GPU time = the next lever.
…al) weight quantizer

First piece of the native-NVFP4 MoE GEMM path (the lever to close the prefill gap
to vLLM 6395@S2048; current moe_q4_grouped_mma ~10 TFLOP/s, NVFP4 mma microbenches
386 TFLOP/s throttled on GB10/sm_121a). quantize_nvfp4(f32 w[m,k]) → E2M1 4-bit codes
(8/u32) + per-block-16 UE4M3 micro-scales + one per-tensor fp32 global, the exact
format the native mma.sync.kind::mxf4nvf4 tensor-core path consumes (validated
bit-exact on hardware separately). Single quant from the original full-precision
weights — no Q4→FP4 double-quant.

Additive only (no existing path touched). Unit test nvfp4_reconstruction_on_par_with_q4
gates recon error within 1.5× of Q4: NVFP4 10.4% vs Q4 8.6% rel_rmse on a [64x2688]
weight → argmax should hold (Q4 at similar error already gives 1104@S2048). + dec_ue4m3
helper for host reference.
…vives NVFP4 activations

De-risks the native-NVFP4 MoE path before building the mma kernel: inject the exact
E2M1+UE4M3-block16 quantization noise the FP4 GEMM would see into the MoE activations
(nvfp4_roundtrip = quantize_nvfp4→dequantize_nvfp4, no new kernel), keep Q4 weights,
run the existing GEMM, check argmax. RESULT @s=2048: batched argmax=1104 == sequential
oracle 1104 (HF gold), cosine 0.973 — FP4 activation noise does NOT flip the argmax.
GREEN LIGHT for MOE_FP4_GROUPED_MMA_SRC.

- ffai-ops: dequantize_nvfp4 (codes+scales+global → f32), nvfp4_roundtrip (both pub;
  the dequant is also needed by the real NVFP4 weight loader).
- ffai-modeltests: NEMOTRON_FP4_SIM gate in the devdesc MoE routed path (round-trips
  the up-input xs + down-input a2). Off by default; default path validated unchanged
  (1826 tok/s, argmax 1104@S2048).
…yout + GEMM pieces

Save the standalone CUDA probes that reverse-engineered + validated the native
Blackwell NVFP4 (mxf4nvf4 m16n8k64 block-scaled) tensor-core mma on GB10/sm_121a,
out of ephemeral /tmp into the repo. These document the full fragment + ue4m3 scale
layout (fp4_aligned.cu = bit-exact reference loader), the in-kernel weight gather
(nvfp4_packed.cu), device act-quant (nvfp4_actq.cu), and perf (nvfp4_tiled.cu 58
TFLOP/s ~6× Q4, nvfp4_peak.cu 386 TFLOP/s). Dev/reference only (not in cargo build);
the production path is quantize_nvfp4 (committed) + the forthcoming
MOE_FP4_GROUPED_MMA_SRC. See README.md + memory note nemotron-grouped-moe-gemm.md.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

feature New feature or capability

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant